{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "f1ee0615-df06-4063-b037-b1a8ad754d58",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-15T09:17:49.423438Z",
     "iopub.status.busy": "2022-06-15T09:17:49.422846Z",
     "iopub.status.idle": "2022-06-15T09:17:49.434276Z",
     "shell.execute_reply": "2022-06-15T09:17:49.433556Z",
     "shell.execute_reply.started": "2022-06-15T09:17:49.423389Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "from datetime import datetime\n",
    "import time \n",
    "import torch\n",
    "import os\n",
    "from transformers import BertModel,BertConfig,BertModel,BertTokenizerFast,get_cosine_schedule_with_warmup,BertForMaskedLM\n",
    "import pandas  as pd\n",
    "import torch \n",
    "import torch.nn as nn \n",
    "import torch.optim as optim \n",
    "import torch.utils.data as Data\n",
    "from torch.utils.tensorboard import SummaryWriter \n",
    " \n",
    "# hyperparameters \n",
    "EPOCH=200\n",
    "RANDOM_SEED=2022 \n",
    "TRAIN_BATCH_SIZE=32  #小批训练， 批大小增大时需要提升学习率  https://zhuanlan.zhihu.com/p/413656738\n",
    "TEST_BATCH_SIZE=96   #大批测试\n",
    "EVAL_PERIOD=20\n",
    "MODEL_NAME=\"bert-base-uncased\"  # bert-base-chinese\n",
    "DATA_PATH=\"dataset/twitter_sentiment/\"\n",
    "MASK_POS=3  # \"it was [mask]\" 中 [mask] 位置\n",
    "train_file=\"twitter-2013train-A.tsv\"\n",
    "dev_file=\"twitter-2013dev-A.tsv\"\n",
    "test_file=\"twitter-2013test-A.tsv\"\n",
    " \n",
    "device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "040677f2-e1af-4afc-a7e2-76e9100a6d7f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-15T09:17:49.675517Z",
     "iopub.status.busy": "2022-06-15T09:17:49.674969Z",
     "iopub.status.idle": "2022-06-15T09:17:49.684274Z",
     "shell.execute_reply": "2022-06-15T09:17:49.683798Z",
     "shell.execute_reply.started": "2022-06-15T09:17:49.675469Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "writer = SummaryWriter('./tb_log')\n",
    "\n",
    "pd.options.display.max_columns = None\n",
    "pd.options.display.max_rows = None\n",
    "\n",
    "prefix = 'It was [mask]. '\n",
    " \n",
    "class Bert_Model(nn.Module):\n",
    "    def __init__(self,  bert_path ,config_file ):\n",
    "        super(Bert_Model, self).__init__()\n",
    "        self.bert = BertForMaskedLM.from_pretrained(bert_path,config=config_file)  # 加载预训练模型权重\n",
    " \n",
    " \n",
    "    def forward(self, input_ids, attention_mask, token_type_ids):\n",
    "        outputs = self.bert(input_ids, attention_mask, token_type_ids) #masked LM 输出的是 mask的值 对应的ids的概率 ，输出 会是词表大小，里面是概率 \n",
    "        logit = outputs[0]  # 池化后的输出 [bs, config.hidden_size]\n",
    "\n",
    "        return logit "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "02d48cd9-9751-4574-a021-3867f920c439",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-15T09:17:49.945689Z",
     "iopub.status.busy": "2022-06-15T09:17:49.945085Z",
     "iopub.status.idle": "2022-06-15T09:17:59.611829Z",
     "shell.execute_reply": "2022-06-15T09:17:59.611306Z",
     "shell.execute_reply.started": "2022-06-15T09:17:49.945637Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "#构建数据集\n",
    "class MyDataSet(Data.Dataset):\n",
    "    def __init__(self, sen , mask , typ ,label ):\n",
    "        super(MyDataSet, self).__init__()\n",
    "        self.sen = torch.tensor(sen,dtype=torch.long)\n",
    "        self.mask = torch.tensor(mask,dtype=torch.long)\n",
    "        self.typ =torch.tensor( typ,dtype=torch.long)\n",
    "        self.label = torch.tensor(label,dtype=torch.long)\n",
    " \n",
    "    def __len__(self):\n",
    "        return self.sen.shape[0]\n",
    " \n",
    "    def __getitem__(self, idx):\n",
    "        return self.sen[idx], self.mask[idx],self.typ[idx],self.label[idx]\n",
    "#load  data\n",
    "   \n",
    "def load_data(tsvpath):\n",
    "    data=pd.read_csv(tsvpath,sep=\"\\t\",header=None,names=[\"sn\",\"polarity\",\"text\"])\n",
    "    data=data[data[\"polarity\"] != \"neutral\"]\n",
    "    yy=data[\"polarity\"].replace({\"negative\":0,\"positive\":1,\"neutral\":2})\n",
    "    return data.values[:,2:3].tolist(),yy.tolist() #data.values[:,1:2].tolist()\n",
    " \n",
    "tokenizer=BertTokenizerFast.from_pretrained(MODEL_NAME)\n",
    "config=BertConfig.from_pretrained(MODEL_NAME)\n",
    "model=Bert_Model(bert_path=MODEL_NAME,config_file=config).to(device)\n",
    " \n",
    "pos_id=tokenizer.convert_tokens_to_ids(\"good\") #9005\n",
    "neg_id=tokenizer.convert_tokens_to_ids(\"bad\")  #12139"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "7b31ff2d-85b4-443a-9ecb-39ea5cb58b79",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-15T09:18:02.424978Z",
     "iopub.status.busy": "2022-06-15T09:18:02.424690Z",
     "iopub.status.idle": "2022-06-15T09:18:03.576634Z",
     "shell.execute_reply": "2022-06-15T09:18:03.576079Z",
     "shell.execute_reply.started": "2022-06-15T09:18:02.424956Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# get the data and label \n",
    "def ProcessData(filepath):\n",
    "    x_train,y_train=load_data(DATA_PATH+os.sep+filepath)\n",
    "    #x_train,x_test,y_train,y_test=train_test_split(StrongData,StrongLabel,test_size=0.3, random_state=42)\n",
    " \n",
    "    Inputid=[]\n",
    "    Labelid=[]\n",
    "    typeid=[]\n",
    "    attenmask=[]\n",
    " \n",
    "    for i in range(len(x_train)):\n",
    " \n",
    "        text_ = prefix+x_train[i][0]\n",
    " \n",
    "        encode_dict = tokenizer.encode_plus(text_,max_length=60,padding=\"max_length\",truncation=True)\n",
    "        input_ids=encode_dict[\"input_ids\"]\n",
    "        type_ids=encode_dict[\"token_type_ids\"]\n",
    "        atten_mask=encode_dict[\"attention_mask\"]\n",
    "        labelid,inputid= input_ids[:],input_ids[:]\n",
    "        if y_train[i] == 0:\n",
    "            labelid[MASK_POS] = neg_id\n",
    "            labelid[:MASK_POS] = [-1]* len(labelid[:MASK_POS]) \n",
    "            labelid[MASK_POS+1:] = [-1] * len(labelid[MASK_POS+1:])\n",
    "            inputid[MASK_POS] = tokenizer.mask_token_id\n",
    "        else:\n",
    "            labelid[MASK_POS] = pos_id\n",
    "            labelid[:MASK_POS] = [-1]* len(labelid[:MASK_POS]) \n",
    "            labelid[MASK_POS+1:] = [-1] * len(labelid[MASK_POS+1:])\n",
    "            inputid[MASK_POS] = tokenizer.mask_token_id\n",
    " \n",
    "        Labelid.append(labelid)\n",
    "        Inputid.append(inputid)\n",
    "        typeid.append(type_ids)\n",
    "        attenmask.append(atten_mask)\n",
    " \n",
    "    return Inputid,Labelid,typeid,attenmask\n",
    " \n",
    "\n",
    "Inputid_train,Labelid_train,typeids_train,inputnmask_train=ProcessData(train_file)\n",
    "Inputid_dev,Labelid_dev,typeids_dev,inputnmask_dev=ProcessData(dev_file)\n",
    "Inputid_test,Labelid_test,typeids_test,inputnmask_test=ProcessData(test_file)\n",
    "\n",
    "train_dataset = Data.DataLoader(MyDataSet(Inputid_train,  inputnmask_train , typeids_train , Labelid_train), TRAIN_BATCH_SIZE, True)\n",
    "valid_dataset = Data.DataLoader(MyDataSet(Inputid_dev,  inputnmask_dev , typeids_dev , Labelid_dev), TRAIN_BATCH_SIZE, True)\n",
    "test_dataset = Data.DataLoader(MyDataSet(Inputid_test,  inputnmask_test , typeids_test , Labelid_test), TEST_BATCH_SIZE, True)\n",
    " \n",
    "train_data_num=len(Inputid_train)\n",
    "test_data_num=len(Inputid_test)\n",
    "#print(\"hello!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "1357e972-fcd8-4db8-b14c-545ba9038c21",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-15T09:18:03.578018Z",
     "iopub.status.busy": "2022-06-15T09:18:03.577622Z",
     "iopub.status.idle": "2022-06-15T09:18:03.582423Z",
     "shell.execute_reply": "2022-06-15T09:18:03.581912Z",
     "shell.execute_reply.started": "2022-06-15T09:18:03.577988Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(160, 5098)"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_dataset), len(train_dataset.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "3e2cb9f3-83ad-4a6e-b58e-66b2c80aa18d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-15T09:18:05.424357Z",
     "iopub.status.busy": "2022-06-15T09:18:05.423768Z",
     "iopub.status.idle": "2022-06-15T09:24:40.107598Z",
     "shell.execute_reply": "2022-06-15T09:24:40.107084Z",
     "shell.execute_reply.started": "2022-06-15T09:18:05.424306Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "正在训练中。。。\n",
      "Epoch 0001 | Step 000020/000160 | Loss 1.2969 | Time 4\n",
      "Epoch 0001 | Step 000040/000160 | Loss 0.8165 | Time 9\n",
      "Epoch 0001 | Step 000060/000160 | Loss 0.6466 | Time 13\n",
      "Epoch 0001 | Step 000080/000160 | Loss 0.5754 | Time 17\n",
      "Epoch 0001 | Step 000100/000160 | Loss 0.5138 | Time 22\n",
      "Epoch 0001 | Step 000120/000160 | Loss 0.4688 | Time 26\n",
      "Epoch 0001 | Step 000140/000160 | Loss 0.4432 | Time 30\n",
      "Epoch 0001 | Step 000160/000160 | Loss 0.4131 | Time 35\n",
      "epoch 1, train_loss 0.4131312052253634,  train_acc 0.8587681443703413 , eval_loss 0.2061903913590041 ,acc_test 0.9228121927236972\n",
      "epoch 1 duration: 39.09117555618286\n",
      "Epoch 0002 | Step 000020/000160 | Loss 0.1164 | Time 4\n",
      "Epoch 0002 | Step 000040/000160 | Loss 0.1335 | Time 9\n",
      "Epoch 0002 | Step 000060/000160 | Loss 0.1306 | Time 13\n",
      "Epoch 0002 | Step 000080/000160 | Loss 0.1292 | Time 17\n",
      "Epoch 0002 | Step 000100/000160 | Loss 0.1295 | Time 22\n",
      "Epoch 0002 | Step 000120/000160 | Loss 0.1349 | Time 26\n",
      "Epoch 0002 | Step 000140/000160 | Loss 0.1368 | Time 31\n",
      "Epoch 0002 | Step 000160/000160 | Loss 0.1419 | Time 35\n",
      "epoch 2, train_loss 0.14190725854132324,  train_acc 0.9454688112985484 , eval_loss 0.21056864749301563 ,acc_test 0.9183874139626352\n",
      "epoch 2 duration: 39.20729088783264\n",
      "Epoch 0003 | Step 000020/000160 | Loss 0.0627 | Time 4\n",
      "Epoch 0003 | Step 000040/000160 | Loss 0.0501 | Time 9\n",
      "Epoch 0003 | Step 000060/000160 | Loss 0.0452 | Time 13\n",
      "Epoch 0003 | Step 000080/000160 | Loss 0.0585 | Time 18\n",
      "Epoch 0003 | Step 000100/000160 | Loss 0.0600 | Time 22\n",
      "Epoch 0003 | Step 000120/000160 | Loss 0.0582 | Time 26\n",
      "Epoch 0003 | Step 000140/000160 | Loss 0.0591 | Time 31\n",
      "Epoch 0003 | Step 000160/000160 | Loss 0.0605 | Time 35\n",
      "epoch 3, train_loss 0.06045952029016917,  train_acc 0.9801883091408395 , eval_loss 0.3267136504026977 ,acc_test 0.9095378564405113\n",
      "epoch 3 duration: 39.35101246833801\n",
      "Epoch 0004 | Step 000020/000160 | Loss 0.0440 | Time 4\n",
      "Epoch 0004 | Step 000040/000160 | Loss 0.0423 | Time 9\n",
      "Epoch 0004 | Step 000060/000160 | Loss 0.0363 | Time 13\n",
      "Epoch 0004 | Step 000080/000160 | Loss 0.0348 | Time 18\n",
      "Epoch 0004 | Step 000100/000160 | Loss 0.0349 | Time 22\n",
      "Epoch 0004 | Step 000120/000160 | Loss 0.0319 | Time 26\n",
      "Epoch 0004 | Step 000140/000160 | Loss 0.0310 | Time 31\n",
      "Epoch 0004 | Step 000160/000160 | Loss 0.0327 | Time 35\n",
      "epoch 4, train_loss 0.032715786315930015,  train_acc 0.989799921537858 , eval_loss 0.3345567044399848 ,acc_test 0.9139626352015733\n",
      "epoch 4 duration: 39.44218349456787\n",
      "Epoch 0005 | Step 000020/000160 | Loss 0.0184 | Time 4\n",
      "Epoch 0005 | Step 000040/000160 | Loss 0.0146 | Time 9\n",
      "Epoch 0005 | Step 000060/000160 | Loss 0.0123 | Time 13\n",
      "Epoch 0005 | Step 000080/000160 | Loss 0.0104 | Time 18\n",
      "Epoch 0005 | Step 000100/000160 | Loss 0.0134 | Time 22\n",
      "Epoch 0005 | Step 000120/000160 | Loss 0.0126 | Time 26\n",
      "Epoch 0005 | Step 000140/000160 | Loss 0.0166 | Time 31\n",
      "Epoch 0005 | Step 000160/000160 | Loss 0.0197 | Time 35\n",
      "epoch 5, train_loss 0.019662532032089074,  train_acc 0.9954884268340526 , eval_loss 0.3952314137396487 ,acc_test 0.9031465093411996\n",
      "epoch 5 duration: 39.51073956489563\n",
      "Epoch 0006 | Step 000020/000160 | Loss 0.0347 | Time 4\n",
      "Epoch 0006 | Step 000040/000160 | Loss 0.0296 | Time 9\n",
      "Epoch 0006 | Step 000060/000160 | Loss 0.0291 | Time 13\n",
      "Epoch 0006 | Step 000080/000160 | Loss 0.0255 | Time 18\n",
      "Epoch 0006 | Step 000100/000160 | Loss 0.0246 | Time 22\n",
      "Epoch 0006 | Step 000120/000160 | Loss 0.0264 | Time 26\n",
      "Epoch 0006 | Step 000140/000160 | Loss 0.0311 | Time 31\n",
      "Epoch 0006 | Step 000160/000160 | Loss 0.0322 | Time 35\n",
      "epoch 6, train_loss 0.03217803146999358,  train_acc 0.989799921537858 , eval_loss 0.48638979413292627 ,acc_test 0.880039331366765\n",
      "epoch 6 duration: 39.56922364234924\n",
      "Epoch 0007 | Step 000020/000160 | Loss 0.0429 | Time 4\n",
      "Epoch 0007 | Step 000040/000160 | Loss 0.0315 | Time 9\n",
      "Epoch 0007 | Step 000060/000160 | Loss 0.0305 | Time 13\n",
      "Epoch 0007 | Step 000080/000160 | Loss 0.0263 | Time 18\n",
      "Epoch 0007 | Step 000100/000160 | Loss 0.0239 | Time 22\n",
      "Epoch 0007 | Step 000120/000160 | Loss 0.0209 | Time 27\n",
      "Epoch 0007 | Step 000140/000160 | Loss 0.0186 | Time 31\n",
      "Epoch 0007 | Step 000160/000160 | Loss 0.0193 | Time 35\n",
      "epoch 7, train_loss 0.01925387882456562,  train_acc 0.9927422518634759 , eval_loss 0.43838024681264703 ,acc_test 0.9169124877089478\n",
      "epoch 7 duration: 39.597825050354004\n",
      "Epoch 0008 | Step 000020/000160 | Loss 0.0078 | Time 4\n",
      "Epoch 0008 | Step 000040/000160 | Loss 0.0119 | Time 9\n",
      "Epoch 0008 | Step 000060/000160 | Loss 0.0126 | Time 13\n",
      "Epoch 0008 | Step 000080/000160 | Loss 0.0131 | Time 18\n",
      "Epoch 0008 | Step 000100/000160 | Loss 0.0129 | Time 22\n",
      "Epoch 0008 | Step 000120/000160 | Loss 0.0148 | Time 27\n",
      "Epoch 0008 | Step 000140/000160 | Loss 0.0135 | Time 31\n",
      "Epoch 0008 | Step 000160/000160 | Loss 0.0145 | Time 35\n",
      "epoch 8, train_loss 0.01446246668265303,  train_acc 0.9950961161239702 , eval_loss 0.41422444181146356 ,acc_test 0.9139626352015733\n",
      "epoch 8 duration: 39.61558246612549\n",
      "Epoch 0009 | Step 000020/000160 | Loss 0.0167 | Time 4\n",
      "Epoch 0009 | Step 000040/000160 | Loss 0.0131 | Time 9\n",
      "Epoch 0009 | Step 000060/000160 | Loss 0.0095 | Time 13\n",
      "Epoch 0009 | Step 000080/000160 | Loss 0.0096 | Time 18\n",
      "Epoch 0009 | Step 000100/000160 | Loss 0.0125 | Time 22\n",
      "Epoch 0009 | Step 000120/000160 | Loss 0.0130 | Time 27\n",
      "Epoch 0009 | Step 000140/000160 | Loss 0.0129 | Time 31\n",
      "Epoch 0009 | Step 000160/000160 | Loss 0.0133 | Time 35\n",
      "epoch 9, train_loss 0.01328287663136507,  train_acc 0.994899960768929 , eval_loss 0.4532659338279204 ,acc_test 0.9060963618485742\n",
      "epoch 9 duration: 39.620912313461304\n",
      "Epoch 0010 | Step 000020/000160 | Loss 0.0114 | Time 4\n",
      "Epoch 0010 | Step 000040/000160 | Loss 0.0090 | Time 9\n",
      "Epoch 0010 | Step 000060/000160 | Loss 0.0124 | Time 13\n",
      "Epoch 0010 | Step 000080/000160 | Loss 0.0128 | Time 18\n",
      "Epoch 0010 | Step 000100/000160 | Loss 0.0149 | Time 22\n",
      "Epoch 0010 | Step 000120/000160 | Loss 0.0170 | Time 27\n",
      "Epoch 0010 | Step 000140/000160 | Loss 0.0163 | Time 31\n",
      "Epoch 0010 | Step 000160/000160 | Loss 0.0157 | Time 35\n",
      "epoch 10, train_loss 0.015654152394563425,  train_acc 0.9950961161239702 , eval_loss 0.3837622517127205 ,acc_test 0.911504424778761\n",
      "epoch 10 duration: 39.6552791595459\n",
      "total training time:  394.66122460365295\n"
     ]
    }
   ],
   "source": [
    "optimizer = AdamW(model.parameters(),lr=2e-5,weight_decay=1e-4)  #使用Adam优化器\n",
    "loss_func = nn.CrossEntropyLoss(ignore_index=-1)\n",
    "EPOCH = 200\n",
    "# schedule = get_cosine_schedule_with_warmup(optimizer,num_warmup_steps=len(train_dataset),num_training_steps=EPOCH*len(train_dataset))\n",
    "print(\"正在训练中。。。\")\n",
    "totaltime=0\n",
    "for epoch in range(10):\n",
    " \n",
    "    starttime_train=datetime.now()\n",
    " \n",
    "    start =time.time()\n",
    "    correct=0\n",
    "    train_loss_sum=0\n",
    "    model.train()\n",
    " \n",
    "    for idx,(ids,att_mask,type,y) in enumerate(train_dataset):\n",
    "        ids,att_mask,type,y = ids.to(device),att_mask.to(device),type.to(device),y.to(device)\n",
    "        out_train = model(ids,att_mask,type)\n",
    "       #print(out_train.view(-1, tokenizer.vocab_size).shape, y.view(-1).shape)\n",
    "        loss = loss_func(out_train.view(-1, tokenizer.vocab_size),y.view(-1))\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # schedule.step()\n",
    "        train_loss_sum += loss.item()\n",
    "       \n",
    "        if( idx+1)% EVAL_PERIOD == 0:\n",
    "            print(\"Epoch {:04d} | Step {:06d}/{:06d} | Loss {:.4f} | Time {:.0f}\".format(\n",
    "                epoch + 1, idx + 1, len(train_dataset), train_loss_sum / (idx + 1), time.time() - start))\n",
    "            writer.add_scalar('loss/train_loss', train_loss_sum / (idx + 1), epoch)\n",
    " \n",
    "        truelabel=y[:,MASK_POS]\n",
    "        out_train_mask=out_train[:,MASK_POS,:]\n",
    "        predicted=torch.max(out_train_mask,1)[1]\n",
    "        correct += (predicted == truelabel).sum()\n",
    "        correct = float(correct)\n",
    "    \n",
    "    acc =float(correct /train_data_num)\n",
    " \n",
    "    eval_loss_sum=0.0\n",
    "    model.eval()\n",
    "    correct_test=0\n",
    "    with torch.no_grad():\n",
    "        for ids, att, tpe, y in test_dataset:\n",
    "            ids, att, tpe, y = ids.to(device), att.to(device), tpe.to(device), y.to(device)\n",
    "            out_test = model(ids , att , tpe)\n",
    "            loss_eval = loss_func(out_test.view(-1, tokenizer.vocab_size), y.view(-1))\n",
    "            eval_loss_sum += loss_eval.item()\n",
    "            ttruelabel = y[:, MASK_POS]\n",
    "            tout_train_mask = out_test[:, MASK_POS, :]\n",
    "            predicted_test = torch.max(tout_train_mask.data, 1)[1]\n",
    "            correct_test += (predicted_test == ttruelabel).sum()\n",
    "            correct_test = float(correct_test)\n",
    "    acc_test = float(correct_test / test_data_num)\n",
    " \n",
    "    if epoch % 1 == 0:\n",
    "        out = (\"epoch {}, train_loss {},  train_acc {} , eval_loss {} ,acc_test {}\"\n",
    "               .format(epoch + 1, train_loss_sum / (len(train_dataset)), acc, eval_loss_sum / (len(test_dataset)),\n",
    "                acc_test))\n",
    "        writer.add_scalar('loss/test_loss', train_loss_sum / (idx + 1), epoch)\n",
    "        print(out)\n",
    "    end=time.time()\n",
    " \n",
    "    print(\"epoch {} duration:\".format(epoch+1),end-start)\n",
    "    totaltime+=end-start\n",
    " \n",
    "print(\"total training time: \",totaltime)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "66504822-4791-46fd-9e3f-5df0681156f6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-15T09:06:33.421138Z",
     "iopub.status.busy": "2022-06-15T09:06:33.420546Z",
     "iopub.status.idle": "2022-06-15T09:06:33.458465Z",
     "shell.execute_reply": "2022-06-15T09:06:33.457367Z",
     "shell.execute_reply.started": "2022-06-15T09:06:33.421088Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_df = pd.read_csv(DATA_PATH + train_file, sep=\"\\t\",header=None,names=[\"sn\",\"polarity\",\"text\"])\n",
    "train_df=train_df[train_df[\"polarity\"] != \"neutral\"]\n",
    "train_df['polarity'] = train_df['polarity'].map({'negative': 0, 'positive': 1, 'neutral': 1})\n",
    "\n",
    "val_df = pd.read_csv(DATA_PATH + dev_file, sep=\"\\t\",header=None,names=[\"sn\",\"polarity\",\"text\"])\n",
    "val_df=val_df[val_df[\"polarity\"] != \"neutral\"]\n",
    "val_df['polarity'] = val_df['polarity'].map({'negative': 0, 'positive': 1, 'neutral': 1})\n",
    "\n",
    "test_df = pd.read_csv(DATA_PATH + test_file, sep=\"\\t\",header=None,names=[\"sn\",\"polarity\",\"text\"])\n",
    "test_df=test_df[test_df[\"polarity\"] != \"neutral\"]\n",
    "test_df['polarity'] = test_df['polarity'].map({'negative': 0, 'positive': 1, 'neutral': 1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "2aedff83-3fc2-4953-9127-7ae9d0571d2d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-15T09:06:33.459590Z",
     "iopub.status.busy": "2022-06-15T09:06:33.459365Z",
     "iopub.status.idle": "2022-06-15T09:06:33.796340Z",
     "shell.execute_reply": "2022-06-15T09:06:33.795025Z",
     "shell.execute_reply.started": "2022-06-15T09:06:33.459573Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5098, 3)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "0a229d73-0dfe-46a5-86a5-968d3c044413",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-15T09:06:33.797824Z",
     "iopub.status.busy": "2022-06-15T09:06:33.797617Z",
     "iopub.status.idle": "2022-06-15T09:06:34.492053Z",
     "shell.execute_reply": "2022-06-15T09:06:34.491464Z",
     "shell.execute_reply.started": "2022-06-15T09:06:33.797805Z"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_encoding = tokenizer(train_df['text'].tolist(), truncation=True, padding=True, max_length=55)\n",
    "test_encoding = tokenizer(test_df['text'].tolist(), truncation=True, padding=True, max_length=55)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "dffc78da-0a90-426a-b7c0-75e0f9f72ef4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-15T09:06:34.493217Z",
     "iopub.status.busy": "2022-06-15T09:06:34.492940Z",
     "iopub.status.idle": "2022-06-15T09:06:34.498014Z",
     "shell.execute_reply": "2022-06-15T09:06:34.497531Z",
     "shell.execute_reply.started": "2022-06-15T09:06:34.493200Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 数据集读取\n",
    "class NewsDataset(Data.Dataset):\n",
    "    def __init__(self, encodings, labels):\n",
    "        self.encodings = encodings\n",
    "        self.labels = labels\n",
    "    \n",
    "    # 读取单个样本\n",
    "    def __getitem__(self, idx):\n",
    "        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
    "        item['labels'] = torch.tensor(int(self.labels[idx]))\n",
    "        return item\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "train_dataset = NewsDataset(train_encoding, train_df['polarity'].tolist())\n",
    "test_dataset = NewsDataset(test_encoding, test_df['polarity'].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "b654f92c-f316-4751-ac28-a3d8060ee1cd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-15T09:06:34.498890Z",
     "iopub.status.busy": "2022-06-15T09:06:34.498730Z",
     "iopub.status.idle": "2022-06-15T09:06:37.482706Z",
     "shell.execute_reply": "2022-06-15T09:06:37.482204Z",
     "shell.execute_reply.started": "2022-06-15T09:06:34.498875Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup\n",
    "model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "# 单个读取到批量读取\n",
    "train_loader = Data.DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "test_dataloader = Data.DataLoader(test_dataset, batch_size=32, shuffle=True)\n",
    "\n",
    "optim = AdamW(model.parameters(),lr=2e-5,weight_decay=1e-4)  #使用Adam优化器\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "59da691b-d9aa-4e6f-be03-dfd782b52dc9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-15T09:06:37.484013Z",
     "iopub.status.busy": "2022-06-15T09:06:37.483838Z",
     "iopub.status.idle": "2022-06-15T09:08:43.967065Z",
     "shell.execute_reply": "2022-06-15T09:08:43.966525Z",
     "shell.execute_reply.started": "2022-06-15T09:06:37.483997Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------Epoch: 0 ----------------\n",
      "epoth: 0, iter_num: 100, loss: 0.3529, 62.50%\n",
      "Epoch: 0, Average training loss: 0.3573\n",
      "Accuracy: 0.9156\n",
      "Average testing loss: 0.2111\n",
      "-------------------------------\n",
      "------------Epoch: 1 ----------------\n",
      "epoth: 1, iter_num: 100, loss: 0.1082, 62.50%\n",
      "Epoch: 1, Average training loss: 0.1582\n",
      "Accuracy: 0.8989\n",
      "Average testing loss: 0.2469\n",
      "-------------------------------\n",
      "------------Epoch: 2 ----------------\n",
      "epoth: 2, iter_num: 100, loss: 0.1826, 62.50%\n",
      "Epoch: 2, Average training loss: 0.0745\n",
      "Accuracy: 0.9192\n",
      "Average testing loss: 0.2933\n",
      "-------------------------------\n",
      "------------Epoch: 3 ----------------\n",
      "epoth: 3, iter_num: 100, loss: 0.0127, 62.50%\n",
      "Epoch: 3, Average training loss: 0.0375\n",
      "Accuracy: 0.9264\n",
      "Average testing loss: 0.3523\n",
      "-------------------------------\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "# 精度计算\n",
    "def flat_accuracy(preds, labels):\n",
    "    pred_flat = np.argmax(preds, axis=1).flatten()\n",
    "    labels_flat = labels.flatten()\n",
    "    return np.sum(pred_flat == labels_flat) / len(labels_flat)\n",
    "\n",
    "# 训练函数\n",
    "def train():\n",
    "    model.train()\n",
    "    total_train_loss = 0\n",
    "    iter_num = 0\n",
    "    total_iter = len(train_loader)\n",
    "    for batch in train_loader:\n",
    "        # 正向传播\n",
    "        optim.zero_grad()\n",
    "        \n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device)\n",
    "        labels = batch['labels'].to(device)\n",
    "        \n",
    "        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        loss = outputs[0]\n",
    "        total_train_loss += loss.item()\n",
    "        \n",
    "        # 反向梯度信息\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "        \n",
    "        # 参数更新\n",
    "        optim.step()\n",
    "\n",
    "        iter_num += 1\n",
    "        if(iter_num % 100==0):\n",
    "            print(\"epoth: %d, iter_num: %d, loss: %.4f, %.2f%%\" % (epoch, iter_num, loss.item(), iter_num/total_iter*100))\n",
    "        \n",
    "    print(\"Epoch: %d, Average training loss: %.4f\"%(epoch, total_train_loss/len(train_loader)))\n",
    "    \n",
    "def validation():\n",
    "    model.eval()\n",
    "    total_eval_accuracy = 0\n",
    "    total_eval_loss = 0\n",
    "    for batch in test_dataloader:\n",
    "        with torch.no_grad():\n",
    "            # 正常传播\n",
    "            input_ids = batch['input_ids'].to(device)\n",
    "            attention_mask = batch['attention_mask'].to(device)\n",
    "            labels = batch['labels'].to(device)\n",
    "            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        \n",
    "        loss = outputs[0]\n",
    "        logits = outputs[1]\n",
    "\n",
    "        total_eval_loss += loss.item()\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "        label_ids = labels.to('cpu').numpy()\n",
    "        total_eval_accuracy += flat_accuracy(logits, label_ids)\n",
    "        \n",
    "    avg_val_accuracy = total_eval_accuracy / len(test_dataloader)\n",
    "    print(\"Accuracy: %.4f\" % (avg_val_accuracy))\n",
    "    print(\"Average testing loss: %.4f\"%(total_eval_loss/len(test_dataloader)))\n",
    "    print(\"-------------------------------\")\n",
    "    \n",
    "\n",
    "for epoch in range(4):\n",
    "    print(\"------------Epoch: %d ----------------\" % epoch)\n",
    "    train()\n",
    "    validation()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70abb67d-b607-49fd-9cad-303254b4a711",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
